Fix WARP_SIZE detection for gfx942 (MI300X) #93

Open
kudomcho wants to merge 1 commit into rocm_enabled_multi_backend from fix/warp-size-gfx942

@kudomcho

Summary

  • Replace broken __GFX9__ guard with __AMDGCN_WAVEFRONT_SIZE (compiler-provided) in both csrc/kernels.hip and csrc/ops.hip
  • Default to WARP_SIZE=64 for CDNA when macro is unavailable
  • Fixes 4-bit GEMV inference kernel producing ~50% element mismatches on MI300X (gfx942)
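
The guard described above could look roughly like the sketch below. This is an illustration, not the PR's actual diff: EXAMPLE_WARP_SIZE and example_warp_size() are hypothetical names chosen so the snippet does not clash with the real WARP_SIZE macro. It relies only on __AMDGCN_WAVEFRONT_SIZE, which the compiler predefines for AMDGPU targets, and falls back to 64 when that macro is absent.

```cpp
// Sketch only: EXAMPLE_WARP_SIZE is a hypothetical stand-in for WARP_SIZE.
#if defined(__AMDGCN_WAVEFRONT_SIZE)
// Use the wavefront width the compiler reports for the current AMDGPU target.
#define EXAMPLE_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
#else
// Fallback taken from the summary: assume a 64-wide wavefront (CDNA, e.g. gfx942).
#define EXAMPLE_WARP_SIZE 64
#endif

// Catch an unexpected width at compile time instead of at kernel launch.
static_assert(EXAMPLE_WARP_SIZE == 32 || EXAMPLE_WARP_SIZE == 64,
              "unexpected wavefront width");

// Small accessor so host code can read the selected width.
constexpr int example_warp_size() { return EXAMPLE_WARP_SIZE; }
```

Keying the guard on the wavefront-size macro rather than on an architecture-family macro means the width follows what the compiler reports for the target, and the fallback only applies when nothing is reported.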

Root Cause

__GFX9__ is not defined at compile time on recent ROCm builds, so WARP_SIZE fell back to 32 on gfx942, which uses 64-wide wavefronts. This affected:

  • ops.hip: grid launch computed num_blocks = (m+3)/4 instead of (m+1)/2, skipping half the output rows
  • kernels.hip: hipcub::WarpReduce template and lane/stride calculations used wrong warp width
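
The grid-size arithmetic in the first bullet can be reproduced with a small host-side sketch. It assumes 128 threads per block and one output row per warp; both are assumptions made here to illustrate the formulas, not values quoted from ops.hip, and all names below are hypothetical. Under those assumptions a width of 32 gives 4 rows per block, hence (m+3)/4 blocks, while a width of 64 gives 2 rows per block, hence (m+1)/2 blocks.

```cpp
// Hypothetical illustration; kThreadsPerBlock and all helper names are assumptions.
constexpr int kThreadsPerBlock = 128;

// Integer ceiling division.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Assumes one output row per warp, so a block handles kThreadsPerBlock / warp_size rows.
constexpr int rows_per_block(int warp_size) { return kThreadsPerBlock / warp_size; }

// Host-side grid size for m output rows.
constexpr int num_blocks(int m, int warp_size) {
    return ceil_div(m, rows_per_block(warp_size));
}

// Rows written when the host sizes the grid for host_warp
// but each block only covers rows_per_block(hw_warp) rows.
inline int rows_covered(int m, int host_warp, int hw_warp) {
    int covered = num_blocks(m, host_warp) * rows_per_block(hw_warp);
    return covered < m ? covered : m;
}
```

With m = 8192, sizing the grid for a width of 32 launches 2048 blocks; if each block then covers only 2 rows, 4096 of the 8192 rows are written. That would be consistent with the roughly 50% mismatch reported for test_gemv_eye_4bit, though it is one plausible reading rather than a confirmed trace of the kernel.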

Same root cause as ROCM-21835. The upstream fix, bitsandbytes-foundation/bitsandbytes PR #1877 (merged 2026-02-24), unified the CUDA/HIP sources around BNB_WARP_SIZE but was never synced to this fork.

Test plan

  • pytest -vvv ./tests/test_functional.py::test_gemv_eye_4bit — 6/6 passed (fp16/bf16/fp32 × nf4/fp4)
  • Full test_functional.py suite regression check

🤖 Generated with Claude Code

Replace broken __GFX9__ guard with __AMDGCN_WAVEFRONT_SIZE
(compiler-provided) and default to 64 for CDNA. The __GFX9__ macro
is not defined at compile time on recent ROCm, causing WARP_SIZE=32
on 64-wide wavefront gfx942 (MI300X). This broke the 4-bit GEMV
inference kernel grid launch and warp reduction, producing ~50%
element mismatches in test_gemv_eye_4bit.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@kudomcho

Author

Test Results — test_gemv_eye_4bit on MI300X (gfx942)

Environment:

  • GPU: AMD Instinct MI300X (gfx942)
  • ROCm: 7.2.2
  • Python: 3.10.12
  • PyTorch: 2.7.0+rocm7.2
  • bitsandbytes: 0.43.3.dev0 (branch fix/warp-size-gfx942)

Before fix (baseline — rocm_enabled_multi_backend HEAD):

tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp16-nf4] FAILED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp16-fp4] FAILED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-bf16-nf4] FAILED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-bf16-fp4] FAILED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp32-nf4] FAILED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp32-fp4] FAILED

Error: AssertionError: Tensor-likes are not close!
       Mismatched elements: ~50% (e.g. 4096/8192)
       Greatest absolute difference: 0.396 (up to 1e-05 allowed)
       Greatest relative difference: inf

================== 6 failed in 87.08s ===================

After fix:

tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp16-nf4] PASSED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp16-fp4] PASSED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-bf16-nf4] PASSED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-bf16-fp4] PASSED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp32-nf4] PASSED
tests/test_functional.py::test_gemv_eye_4bit[DQ_True-fp32-fp4] PASSED

================== 6 passed, 74 warnings in 92.53s ===================
